{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Augmented Normalizing Flow based on Real NVP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required packages\n",
    "import torch\n",
    "import numpy as np\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set up model\n",
    "\n",
    "# Define flows\n",
    "K = 32  # number of (MaskedAffineFlow, ActNorm) pairs\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Dimension of the flow. The augmented target below pairs TwoMoons with\n",
    "# DiagGaussian(2); keep this value consistent with that target.\n",
    "latent_size = 4\n",
    "# Mask for the coupling layers: ones on the first half of the coordinates,\n",
    "# zeros on the second half\n",
    "b = torch.Tensor([1] * (latent_size // 2) + [0] * (latent_size // 2))\n",
    "flows = []\n",
    "for i in range(K):\n",
    "    # Two fresh MLPs per layer, handed to the coupling layer as s and t\n",
    "    s = nf.nets.MLP([latent_size, 4 * latent_size, latent_size], init_zeros=True)\n",
    "    t = nf.nets.MLP([latent_size, 4 * latent_size, latent_size], init_zeros=True)\n",
    "    # Alternate between the mask and its complement from layer to layer\n",
    "    mask = b if i % 2 == 0 else 1 - b\n",
    "    flows += [nf.flows.MaskedAffineFlow(mask, t, s)]\n",
    "    flows += [nf.flows.ActNorm(latent_size)]\n",
    "\n",
    "# Set augmented target\n",
    "target = nf.distributions.TwoIndependent(nf.distributions.TwoMoons(),\n",
    "                                         nf.distributions.DiagGaussian(2))\n",
    "# Set base distribution; its dimension follows the flow dimension instead of\n",
    "# repeating the literal 4\n",
    "q0 = nf.distributions.DiagGaussian(latent_size)\n",
    "\n",
    "# Construct flow model\n",
    "nfm = nf.NormalizingFlow(q0=q0, flows=flows, p=target)\n",
    "\n",
    "# Move model on GPU if available\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
    "nfm = nfm.to(device)\n",
    "nfm = nfm.double()\n",
    "\n",
    "# Initialize ActNorm by drawing one batch from the model. Neither this pass\n",
    "# nor the plots below need gradients, so autograd is switched off.\n",
    "with torch.no_grad():\n",
    "    z, _ = nfm.sample(num_samples=2 ** 7)\n",
    "z_np = z.detach().cpu().numpy()\n",
    "\n",
    "# One 2-D histogram per coordinate pair of the initial samples\n",
    "for (ix, iy), title in (((0, 1), \"Standard coordinates\"),\n",
    "                        ((2, 3), \"Augmented coordinates\")):\n",
    "    plt.figure(figsize=(15, 15))\n",
    "    plt.hist2d(z_np[:, ix], z_np[:, iy], (50, 50), range=[[-3, 3], [-3, 3]])\n",
    "    plt.gca().set_aspect('equal', 'box')\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot augmented target\n",
    "z = target.sample(num_samples=2 ** 16)\n",
    "z_np = z.detach().cpu().numpy()\n",
    "\n",
    "# Coordinate pair and title for each of the two histograms\n",
    "panels = [((0, 1), \"Standard coordinates\"),\n",
    "          ((2, 3), \"Augmented coordinates\")]\n",
    "\n",
    "for (ix, iy), title in panels:\n",
    "    plt.figure(figsize=(15, 15))\n",
    "    plt.hist2d(z_np[:, ix], z_np[:, iy], (50, 50), range=[[-3, 3], [-3, 3]])\n",
    "    plt.gca().set_aspect('equal', 'box')\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 20000\n",
    "# NOTE(review): this evaluates to 20 samples per step; confirm that\n",
    "# `2 ** 10` was not intended\n",
    "num_samples = 2 * 10\n",
    "anneal_iter = 10000\n",
    "show_iter = 1000\n",
    "\n",
    "# Coordinate pair and title for each of the two progress histograms\n",
    "panels = (((0, 1), \"Standard coordinates\"),\n",
    "          ((2, 3), \"Augmented coordinates\"))\n",
    "\n",
    "# Losses are collected in a list and converted once after the loop;\n",
    "# np.append would copy the whole history on every iteration\n",
    "loss_hist = []\n",
    "\n",
    "optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-4, weight_decay=1e-6)\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "    # Annealing weight: starts at 0.01, grows linearly with the iteration\n",
    "    # count and is capped at 1\n",
    "    beta = min(1., 0.01 + it / anneal_iter)\n",
    "    loss = nfm.reverse_kld(num_samples, beta=beta)\n",
    "\n",
    "    # Skip the parameter update when the loss is NaN or infinite\n",
    "    if torch.isfinite(loss):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # The loss is recorded even when the update was skipped\n",
    "    loss_hist.append(loss.item())\n",
    "\n",
    "    # Plot learned posterior\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        # Samples drawn only for plotting do not need an autograd graph\n",
    "        with torch.no_grad():\n",
    "            z, _ = nfm.sample(num_samples=2 ** 14)\n",
    "        z_np = z.detach().cpu().numpy()\n",
    "\n",
    "        for (ix, iy), title in panels:\n",
    "            plt.figure(figsize=(15, 15))\n",
    "            plt.hist2d(z_np[:, ix], z_np[:, iy], (50, 50), range=[[-3, 3], [-3, 3]])\n",
    "            plt.gca().set_aspect('equal', 'box')\n",
    "            plt.title(title)\n",
    "            plt.show()\n",
    "\n",
    "# Expose the history as a NumPy array, as before\n",
    "loss_hist = np.array(loss_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot loss\n",
    "# Draw the recorded loss history on an explicitly created axes object\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.plot(loss_hist, label='loss')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot learned distribution\n",
    "# The samples are only histogrammed, so draw them without building an\n",
    "# autograd graph through the flow layers\n",
    "with torch.no_grad():\n",
    "    z, _ = nfm.sample(num_samples=2 ** 16)\n",
    "z_np = z.detach().cpu().numpy()\n",
    "\n",
    "# One 2-D histogram per coordinate pair\n",
    "for (ix, iy), title in (((0, 1), \"Standard coordinates\"),\n",
    "                        ((2, 3), \"Augmented coordinates\")):\n",
    "    plt.figure(figsize=(15, 15))\n",
    "    plt.hist2d(z_np[:, ix], z_np[:, iy], (50, 50), range=[[-3, 3], [-3, 3]])\n",
    "    plt.gca().set_aspect('equal', 'box')\n",
    "    plt.title(title)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
